[LTX-2] Run Gemma-3 Text Encoder natively in JAX via TorchAX #398
Conversation
Force-pushed from 252c34e to a449a5c
Force-pushed from a449a5c to d681d61
```python
text_input_ids = jnp.array(text_inputs.input_ids)
prompt_attention_mask = jnp.array(text_inputs.attention_mask)

# Distribute the batch dimension across available TPUs to prevent Softmax OOM
```
Have we tested this on Trillium? Since Trillium has less HBM, I'm wondering whether it will cause OOM issues there. If so, we might have to consider TP (see the sketch below).
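For context, a minimal sketch (not this PR's actual code; the shapes, axis names, and the TP note are illustrative assumptions) of batch-dimension sharding in JAX, and where a tensor-parallel layout would differ:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all local devices; the "data" axis shards the batch dimension.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Hypothetical tokenized batch: (batch, seq_len); the batch must divide evenly
# across the mesh for this sharding to apply.
text_input_ids = jnp.zeros((8, 256), dtype=jnp.int32)

# Spread the batch across chips so per-device Softmax activations shrink.
text_input_ids = jax.device_put(text_input_ids, NamedSharding(mesh, P("data", None)))

# A TP fallback for HBM-constrained chips (e.g. Trillium) would instead shard
# weight/feature dimensions, e.g. P(None, "model") on a ("data", "model") mesh,
# trading activation savings for lower per-device weight memory.
```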
```python
)
text_encoder.eval()

with default_env():
```
Should we add a config param that lets users choose whether to run the text encoder on CPU vs. TPU? It could be useful when dealing with older chips with lower HBM. A rough sketch of such a flag follows below.
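For illustration only; the `text_encoder_device` name, the `config` object, and the encoder call are assumptions, not this PR's API:

```python
import jax

# Hypothetical config flag; defaults to TPU placement when unset.
text_encoder_device = getattr(config, "text_encoder_device", "tpu")

# On older chips with little HBM, fall back to host RAM for the encoder.
device = jax.devices("cpu")[0] if text_encoder_device == "cpu" else jax.devices()[0]

with jax.default_device(device):
    prompt_embeds = text_encoder(text_input_ids)  # illustrative encoder call
```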
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR successfully integrates TorchAX for the LTX-2 pipeline's text encoder, bringing significant performance improvements and memory optimizations on TPU. The transition from eager PyTorch to JAX-native execution is well-implemented, and the additional sharding constraints for both the text encoder and VAE are effective strategies for preventing OOM crashes.
🔍 General Feedback
- **TorchAX Integration:** The use of `TorchaxGemma3TextEncoder` and the manual batch sharding logic is a great addition for efficiency.
- **Memory Management:** The conditional sharding and slicing disabling in the VAE decoding loop correctly addresses HBM issues for larger batches.
- **Distributed Performance:** One critical observation is the explicit un-sharding of text encoder hidden states to a single device, which should be avoided to ensure optimal performance in multi-host environments.
- **Code Cleanliness:** Small refactors to use `getattr` instead of broad `try/except` blocks will improve maintainability.
```python
except Exception:  # pylint: disable=broad-exception-caught
    replicate_vae = False
if replicate_vae:
    sharding = NamedSharding(mesh, P())
```
🟢 Using `getattr` is cleaner than a `try/except` block for checking optional config parameters.
Suggested change:

```diff
-except Exception:  # pylint: disable=broad-exception-caught
-    replicate_vae = False
+replicate_vae = getattr(config, "replicate_vae", False)
 if replicate_vae:
     sharding = NamedSharding(mesh, P())
```
```python
except Exception:  # pylint: disable=broad-exception-caught
    replicate_vae = False
if replicate_vae:
    sharding = NamedSharding(mesh, P())
```
🟢 Same as above, `getattr` is preferred for optional configuration attributes.
Suggested change:

```diff
-except Exception:  # pylint: disable=broad-exception-caught
-    replicate_vae = False
+replicate_vae = getattr(config, "replicate_vae", False)
 if replicate_vae:
     sharding = NamedSharding(mesh, P())
```
```python
state = jax.device_put(state, jax.devices()[0])
prompt_embeds_list.append(state.astype(jnp.bfloat16))

prompt_embeds = prompt_embeds_list
```
🟠 This `jax.device_put` call explicitly un-shards the hidden states to a single device (global device 0). This is detrimental in multi-host setups as it forces cross-host communication and can cause a memory bottleneck on the first TPU. Since the hidden states are already JAX arrays computed on the TPU mesh (and will be correctly resharded to the transformer's mesh at line 1367), this step should be removed to maintain efficient distributed execution.
Suggested change:

```diff
 for state in text_encoder_hidden_states:
-    state = jax.device_put(state, jax.devices()[0])
     prompt_embeds_list.append(state.astype(jnp.bfloat16))
 prompt_embeds = prompt_embeds_list
```
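If an explicit placement is ever needed later, one option (a sketch only, assuming the `mesh` and batch axis name used elsewhere in this pipeline) is to reshard onto the mesh instead of a single device:

```python
# Keep the embeds distributed: reshard onto the transformer's mesh rather than
# gathering everything onto global device 0.
embeds_sharding = NamedSharding(mesh, P("data", None, None))  # (batch, seq, dim)
prompt_embeds = [
    jax.device_put(state.astype(jnp.bfloat16), embeds_sharding)
    for state in text_encoder_hidden_states
]
```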
```diff
-for i in range(len(timesteps_jax)):
-    t = timesteps_jax[i]
+for _, t in enumerate(timesteps_jax):
```
🟢 The index `_` is not used. You can simplify this to a direct iteration.
Suggested change:

```diff
-for _, t in enumerate(timesteps_jax):
+for t in timesteps_jax:
```
```python
    return inner_mask


transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
```
🟡 Global monkeypatching of `transformers.masking_utils` can have unintended side effects if other models in the same process rely on the original behavior. While this is a necessary workaround for TorchAX + Gemma-3, consider documenting the sequence-length assumption more explicitly or ensuring this patch doesn't break other potential future Gemma-based models in the same environment; one way to scope the patch is sketched below.
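One possible mitigation, sketched under the assumption that `_patched_sliding_window_overlay` is defined as in this PR: wrap the patch in a context manager so it is only active around the Gemma-3 encoder call and the original is always restored:

```python
import contextlib

import transformers.masking_utils as masking_utils

@contextlib.contextmanager
def sliding_window_patch():
    """Temporarily swap in the TorchAX workaround, then restore the original."""
    original = masking_utils.sliding_window_overlay
    masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
    try:
        yield
    finally:
        masking_utils.sliding_window_overlay = original

# Illustrative usage: other Gemma-based models in the same process keep the
# stock masking behavior outside this scope.
# with sliding_window_patch():
#     hidden_states = text_encoder(text_input_ids)
```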
Description
This PR transitions the LTX-2 pipeline's text encoding to TorchAX, bridging the Gemma-3 model natively into JAX and significantly reducing memory usage to prevent TPU out-of-memory errors. Minor PyLint warnings across the pipeline were also resolved during the refactor.
Key changes include:

- The Gemma-3 text encoder now runs natively in JAX through `TorchaxGemma3TextEncoder`. TPU sharding is now manually distributed across the batch dimension via `jax.device_put` to prevent Softmax OOM crashes.
- In the VAE decoding loop, sharding is applied conditionally and slicing is disabled for `batch_size > 2`, so HBM crashes during decoding are avoided.

Benchmarks
Performance comparison demonstrating latency improvements from TorchAX integration.